/*
 * Copyright 2020 zytech
 * All rights reserved.
 *
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include "erpc_server_tcp_transport.h"
#include <cstdio>
#include <errno.h>
#include <string>
#include <sys/types.h>
#include <memory>
#include <thread>
#include <functional>
// Platform abstraction for the socket API. Each branch defines the same set of
// helper macros so the rest of this file can be written once.
// If neither WINSOCK_USED nor __unix is defined, none of these macros exist.
#ifdef WINSOCK_USED
#include <WS2tcpip.h>
// Function used to release a socket handle.
#define _CLOSE_SOCKET ::closesocket
// "Both directions" argument for ::shutdown().
#define _SD_BOTH_ SD_BOTH
// True when `s` holds the platform's invalid-socket value.
#define _IS_INVALID_SOCKET_(s) (s == INVALID_SOCKET)
// Calling-convention keyword spliced into the function_traits specialization.
#define _STDCALL_MODIFIER __stdcall
#pragma comment(lib,"Ws2_32.lib ")
#elif defined(__unix)
// Function used to release a socket descriptor.
#define _CLOSE_SOCKET ::close
// "Both directions" argument for ::shutdown().
#define _SD_BOTH_ SHUT_RDWR
// True when `s` holds an invalid (negative) descriptor.
#define _IS_INVALID_SOCKET_(s) (s < 0)
// No calling-convention keyword on this platform.
#define _STDCALL_MODIFIER
#include <err.h>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <fcntl.h>
#endif
#include "threadsafe_queue.h"

#ifndef function_traits_defined
#define function_traits_defined
/// Compile-time description of a plain function type.
/// The primary template is only declared; all information comes from the
/// specialization below.
template<typename T>
struct function_traits;

/// Specialization for function types of the form `R(Args...)`, including the
/// platform calling-convention modifier where one is defined.
template<typename R, typename ...Args>
struct function_traits<R _STDCALL_MODIFIER (Args...)>
{
	/// Number of parameters the function takes.
	static const size_t nargs = sizeof...(Args);

	/// Return type of the function.
	using result_type = R;

	/// Parameter type lookup; `Index` is the zero-based parameter position.
	template <size_t Index>
	struct arg
	{
		using type = typename std::tuple_element<Index, std::tuple<Args...>>::type;
	};
};
#endif /** function_traits_defined */

using namespace erpc;
using namespace gdface;

// Set this to 1 to enable debug logging.
// TODO fix issue with the transport not working on Linux if debug logging is disabled.
//#define TCP_TRANSPORT_DEBUG_LOG (1)

// Debug logging helpers. They expand to nothing unless
// TCP_TRANSPORT_DEBUG_LOG is defined to a non-zero value.
#if TCP_TRANSPORT_DEBUG_LOG
// printf-style trace output.
#define TCP_DEBUG_PRINT(_fmt_, ...) printf(_fmt_, ##__VA_ARGS__)
// NOTE(review): err() from <err.h> prints the message and then terminates the
// process, so enabling debug logging turns these reports into fatal exits — confirm intended.
#define TCP_DEBUG_ERR(_msg_) err(errno, _msg_)
#else
#define TCP_DEBUG_PRINT(_fmt_, ...)
#define TCP_DEBUG_ERR(_msg_)
#endif
// Error-log hook. It currently always expands to nothing, so the messages
// passed to it are dropped.
#define TCP_LOG_ERR(_msg_)

// Client socket currently served by the calling thread. The variable is
// thread_local, so every thread sees its own value; -1 means "no socket".
thread_local _SOCK_TYPE_ ServerTCPTransport::m_socket = -1;
// Default send/receive timeout in milliseconds; both constructors start from it.
const uint32_t ServerTCPTransport::DEFAULT_TIMEOUT_MILLS = 4000;

// Queue of accepted client sockets: serverThread() pushes into it and
// underlyingReceive() pops from it.
using reqsocket_queue = threadsafe_queue<_SOCK_TYPE_>;
///  创建请求socket队列

static std::shared_ptr<void> create_reqsocket_queue(){
	return std::shared_ptr<void>(new reqsocket_queue(),
		[](void *p){
		if(p){
			TCP_DEBUG_PRINT("force close all connected socket");
			auto queue = (reqsocket_queue*)p;
			typename reqsocket_queue::value_type socket;
			// 强制关闭所有socket
			while(queue->try_pop(socket)){
				if(socket != -1){
					_CLOSE_SOCKET(socket);
				}
			}
			delete queue;
		}
	});
}
/// Recover the typed request-socket queue from the transport's type-erased handle.
/// The returned pointer is non-owning; the transport keeps the queue alive.
static reqsocket_queue* reqsocket_queue_of(ServerTCPTransport*transport){
	std::shared_ptr<void> erased = transport->getReqSocketQueue();
	return static_cast<reqsocket_queue*>(erased.get());
}
////////////////////////////////////////////////////////////////////////////////
// Code
////////////////////////////////////////////////////////////////////////////////

/// Construct an unconfigured transport: no host, port 0, a fresh request-socket
/// queue and the default send/receive timeouts. The server thread object is
/// bound to serverThreadStub here; open() is what starts it.
ServerTCPTransport::ServerTCPTransport()
    : m_host(nullptr),
      m_port(0),
      m_reqSockets(create_reqsocket_queue()),
      m_serverThread(serverThreadStub),
      m_runServer(true),
      m_sendTimeoutMills(DEFAULT_TIMEOUT_MILLS),
      m_recvTimeoutMills(DEFAULT_TIMEOUT_MILLS)
{
}

/// Construct a transport for the given host and port, with a fresh
/// request-socket queue and the default send/receive timeouts. The server
/// thread object is bound to serverThreadStub here; open() is what starts it.
///
/// @param host Host value stored in m_host.
/// @param port TCP port stored in m_port.
ServerTCPTransport::ServerTCPTransport(const char *host, uint16_t port)
    : m_host(host),
      m_port(port),
      m_reqSockets(create_reqsocket_queue()),
      m_serverThread(serverThreadStub),
      m_runServer(true),
      m_sendTimeoutMills(DEFAULT_TIMEOUT_MILLS),
      m_recvTimeoutMills(DEFAULT_TIMEOUT_MILLS)
{
}

/// Destructor. Performs no work of its own; members clean up after themselves.
/// NOTE(review): it neither stops the server thread nor closes the client
/// socket — presumably callers invoke close() first; confirm.
ServerTCPTransport::~ServerTCPTransport(void) = default;

/// Replace the stored host and port.
///
/// @param host New host value assigned to m_host.
/// @param port New TCP port assigned to m_port.
void ServerTCPTransport::configure(const char *host, uint16_t port)
{
    m_port = port;
    m_host = host;
}

/// Update the send and/or receive timeout, both in milliseconds.
/// A value of 0 leaves the corresponding setting untouched.
///
/// @param sendTimeoutMills New send timeout, or 0 to keep the current one.
/// @param recvTimeoutMills New receive timeout, or 0 to keep the current one.
void ServerTCPTransport::timeout(uint32_t sendTimeoutMills,uint32_t recvTimeoutMills)
{
	const bool updateSend = (sendTimeoutMills != 0);
	const bool updateRecv = (recvTimeoutMills != 0);

	if (updateSend) {
		m_sendTimeoutMills = sendTimeoutMills;
	}
	if (updateRecv) {
		m_recvTimeoutMills = recvTimeoutMills;
	}
}
/// Start the server thread.
///
/// The run flag is raised before the thread is started so the accept loop in
/// serverThread() sees it set.
///
/// @return Always kErpcStatus_Success.
erpc_status_t ServerTCPTransport::open(void)
{
    // Raise the flag first: the loop in serverThread() tests it.
    m_runServer = true;

    m_serverThread.start(this);

    return kErpcStatus_Success;
}

/// Request server shutdown and close the calling thread's client socket.
///
/// Clearing m_runServer makes the loop in serverThread() stop the next time it
/// tests the flag. Because m_socket is thread_local, only the socket owned by
/// the calling thread is closed here.
///
/// @return Always kErpcStatus_Success.
erpc_status_t ServerTCPTransport::close(void)
{
    TCP_DEBUG_PRINT("Server TCP transport closed\n");

    m_runServer = false;
    (void)closeClientSocket();

    return kErpcStatus_Success;
}
/// Shut down and close the calling thread's client socket, if it has one.
/// Afterwards m_socket is reset to -1.
///
/// @return Always kErpcStatus_Success.
erpc_status_t ServerTCPTransport::closeClientSocket(void)
{
    // Nothing to do when this thread holds no client socket.
    if (m_socket == -1)
    {
        return kErpcStatus_Success;
    }

    TCP_DEBUG_PRINT("client socket %d closed\n",m_socket);
    // Stop traffic in both directions, then release the handle.
    ::shutdown(m_socket, _SD_BOTH_);
    _CLOSE_SOCKET(m_socket);
    m_socket = -1;

    return kErpcStatus_Success;
}

/// Return the type-erased handle to the request-socket queue.
/// The caller receives a shared owner of the same queue object.
std::shared_ptr<void> ServerTCPTransport::getReqSocketQueue(void)
{
    return m_reqSockets;
}
/// Set the client socket used by the calling thread (m_socket is thread_local).
///
/// @param socket Socket handle to store.
void ServerTCPTransport::setSocket(_SOCK_TYPE_ socket)
{
    m_socket = socket;
}
/// Return the client socket of the calling thread (m_socket is thread_local);
/// -1 when the thread has none.
_SOCK_TYPE_ ServerTCPTransport::getSocket(void)
{
    return m_socket;
}
/**
 * Receive exactly @p size bytes from the calling thread's client socket.
 *
 * If the calling thread has no client socket yet, this blocks until the accept
 * loop hands one over through the request-socket queue.
 *
 * @param data Destination buffer; must have room for @p size bytes.
 * @param size Number of bytes to read.
 *
 * @retval kErpcStatus_Success          All requested bytes were received.
 * @retval kErpcStatus_ConnectionClosed The peer closed the connection before
 *                                      all bytes arrived.
 * @retval kErpcStatus_ReceiveFailed    recv() reported an error.
 *
 * On either failure the client socket is closed, so the next call waits for a
 * new connection from the queue.
 */
erpc_status_t ServerTCPTransport::underlyingReceive(uint8_t *data, uint32_t size)
{
    // Block until we have a valid connection. The explicit invalid-socket test
    // also covers a socket type that is unsigned, where the invalid value is
    // not <= 0.
    if (m_socket <= 0 || _IS_INVALID_SOCKET_(m_socket))
    {
        m_socket = reqsocket_queue_of(this)->wait_and_pop();
    }

    decltype(recv(0, nullptr, 0, 0)) length = 0;
    // Loop until all requested data is received.
    while (size)
    {
#ifdef WINSOCK_USED
        length = ::recv(m_socket, (char*)data, (int)size, 0);
#else
        length = ::recv(m_socket, data, size, MSG_WAITALL);
#endif
        if (length < 0)
        {
            if (errno == EAGAIN)
            {
                TCP_DEBUG_PRINT("RECEIVE ERROR EAGAIN\n");
            }
            else if (errno == EWOULDBLOCK)
            {
                TCP_DEBUG_PRINT("RECEIVE ERROR EWOULDBLOCK\n");
            }
            closeClientSocket();
            return kErpcStatus_ReceiveFailed;
        }
        if (length == 0)
        {
            // recv() returns 0 when the peer has shut the connection down.
            // No more data will ever arrive, so stop here instead of looping
            // forever with `size` unchanged.
            closeClientSocket();
            return kErpcStatus_ConnectionClosed;
        }

        size -= static_cast<uint32_t>(length);
        data += length;
    }

    return kErpcStatus_Success;
}

/**
 * Send exactly @p size bytes on the calling thread's client socket.
 *
 * @param data Bytes to send.
 * @param size Number of bytes to send.
 *
 * @retval kErpcStatus_Success          All bytes were handed to the socket.
 * @retval kErpcStatus_ConnectionClosed send() failed with EPIPE; the client
 *                                      socket has been closed.
 * @retval kErpcStatus_SendFailed       The thread has no client socket, or
 *                                      send() failed with another error.
 */
erpc_status_t ServerTCPTransport::underlyingSend(const uint8_t *data, uint32_t size)
{
    if (m_socket <= 0)
    {
        return kErpcStatus_SendFailed;
    }

    // Where available, ask for an EPIPE error instead of a SIGPIPE signal when
    // the peer has gone away. The default action of SIGPIPE terminates the
    // process, which would make the EPIPE handling below unreachable.
#if !defined(WINSOCK_USED) && defined(MSG_NOSIGNAL)
    const int sendFlags = MSG_NOSIGNAL;
#else
    const int sendFlags = 0;
#endif

    // Loop until all data is sent.
    while (size)
    {
#ifdef WINSOCK_USED
        auto result = ::send(m_socket, reinterpret_cast<const char*>(data), (int)size, sendFlags);
#else
        auto result = ::send(m_socket, data, size, sendFlags);
#endif
        if (result < 0)
        {
            if (errno == EPIPE)
            {
                TCP_DEBUG_PRINT("SEND ERROR EPIPE\n");
                // The connection is gone; release the socket so the next
                // receive waits for a fresh connection.
                closeClientSocket();
                return kErpcStatus_ConnectionClosed;
            }
            return kErpcStatus_SendFailed;
        }

        size -= static_cast<uint32_t>(result);
        data += result;
    }

    return kErpcStatus_Success;
}

void ServerTCPTransport::serverThread(void)
{
    TCP_DEBUG_PRINT("in server thread\n");
#ifdef WINSOCK_USED
	WORD wVersionRequested;
	WSADATA wsaData;
	int err;

	wVersionRequested = MAKEWORD(2, 2);

	err = WSAStartup(wVersionRequested, &wsaData);
	if (err != 0) {
		TCP_DEBUG_ERR("failed to WSAStartup");
		return;
	}
	if (LOBYTE(wsaData.wVersion) != 2 ||
		HIBYTE(wsaData.wVersion) != 2) {
		WSACleanup();
		TCP_DEBUG_ERR("Winsocket version mismatch");
		return;
	}
#endif
    // Create socket.
	_SOCK_TYPE_ serverSocket = socket(AF_INET, SOCK_STREAM, 0);
	if (_IS_INVALID_SOCKET_(serverSocket))
    {
        TCP_DEBUG_ERR("failed to create server socket");
        return;
    }

    // Fill in address struct.
    struct sockaddr_in serverAddress;
    memset(&serverAddress, 0, sizeof(serverAddress));
    serverAddress.sin_family = AF_INET;
    serverAddress.sin_addr.s_addr = INADDR_ANY; // htonl(local ? INADDR_LOOPBACK : INADDR_ANY);
    serverAddress.sin_port = htons(m_port);

    // Turn on reuse address option.
    int yes = 1;
    int result = setsockopt(serverSocket, SOL_SOCKET, SO_REUSEADDR, (char*)&yes, sizeof(yes));
    if (result < 0)
    {
        TCP_DEBUG_ERR("setsockopt SO_REUSEADDR failed");
        _CLOSE_SOCKET(serverSocket);
        return;
    }
    // Bind socket to address.
    result = bind(serverSocket, (struct sockaddr *)&serverAddress, sizeof(serverAddress));
    if (result < 0)
    {
        TCP_DEBUG_ERR("bind failed");
        _CLOSE_SOCKET(serverSocket);
        return;
    }

    // Listen for connections.
    result = listen(serverSocket, 1);
    if (result < 0)
    {
        TCP_DEBUG_ERR("listen failed");
		_CLOSE_SOCKET(serverSocket);
        return;
    }

    TCP_DEBUG_PRINT("Listening for connections\n");

    while (m_runServer)
    {
        struct sockaddr_in incomingAddress;
		using addrlen_type = std::remove_pointer<typename function_traits<decltype(accept)>::arg<2>::type>::type;
		addrlen_type incomingAddressLength = (addrlen_type)sizeof(struct sockaddr_in);
        int incomingSocket = accept(serverSocket, (struct sockaddr*)&incomingAddress, &incomingAddressLength);
		if (!_IS_INVALID_SOCKET_(incomingSocket))
        {
            char *client_ip = inet_ntoa(incomingAddress.sin_addr);
            int client_port = ntohs(incomingAddress.sin_port);
            TCP_DEBUG_PRINT("incoming %s:%d socket %d\n",client_ip,client_port,incomingSocket);
#ifndef WINSOCK_USED
			struct timeval timeout = {(__time_t)(m_recvTimeoutMills/1000),(__suseconds_t)((m_recvTimeoutMills % 1000) * 1000)};
			socklen_t timeout_len = sizeof(timeout);
			// receive timeout
			auto result = setsockopt(incomingSocket,SOL_SOCKET,SO_RCVTIMEO,(char *)&timeout,timeout_len);
			if (result < 0)
			{
				TCP_LOG_ERR("RECEIVE setsockopt SO_RCVTIMEO failed");
			}
			timeout = {(__time_t)(m_sendTimeoutMills/1000),(__suseconds_t)((m_sendTimeoutMills % 1000 ) * 1000)};
			// send timeout
			result = setsockopt(incomingSocket,SOL_SOCKET,SO_SNDTIMEO,(char *)&timeout,timeout_len);
			if (result < 0)
			{
				TCP_LOG_ERR("SEND setsockopt SO_SNDTIMEO failed");
			}
#endif
            // Successfully accepted a connection.
        	reqsocket_queue_of(this)->push(incomingSocket);
        	//TCP_DEBUG_PRINT("in serverThread reqsocket_queue %p empty:%d,size %ld\n",getReqSocketQueue().get(),reqsocket_queue_of(this)->empty(),reqsocket_queue_of(this)->size());
        }
        else
        {
        	TCP_LOG_ERR("accept failed");
        }
    }

	_CLOSE_SOCKET(serverSocket);
}

/// Thread entry point: forwards to serverThread() on the transport passed as
/// the thread argument. A null argument is ignored.
///
/// @param arg Pointer to the ServerTCPTransport that owns the thread.
void ServerTCPTransport::serverThreadStub(void *arg)
{
    auto *self = static_cast<ServerTCPTransport *>(arg);
    TCP_DEBUG_PRINT("in serverThreadStub (arg=%p)\n", arg);
    if (self == nullptr)
    {
        return;
    }
    self->serverThread();
}
